Compare simple difference in functional effects across two conditions¶

Import Python modules. We use polyclonal for the plotting:

In [1]:
import copy
import itertools

import altair as alt

import dms_variants.utils

import pandas as pd

import polyclonal
import polyclonal.plot

This notebook is parameterized by `papermill`. The next cell is tagged `parameters`; it only defines placeholder values, and the values actually passed to the notebook are injected in the cell after it.

In [2]:
# this cell is tagged parameters for `papermill` parameterization
# The `None` values are placeholders only; the real values are assigned in the
# injected `# Parameters` cell that follows.
site_numbering_map_csv = None  # input CSV read with `pd.read_csv`; its `reference_site` column is renamed to `site`
diffs_csv = None  # output CSV to which the table of differences is written
chart_html = None  # output HTML file to which the interactive chart is saved
params = None  # dict of settings (conditions, `avg_method`, `plot_kwargs`, ...)
In [3]:
# Parameters
# Values for this run; they replace the `None` placeholders of the tagged
# parameters cell.
params = {
    # Each condition has a `name` and the list of `selections` belonging to it.
    "condition_1": {
        "name": 220210,
        "selections": ["LibA-220210-293T_ACE2-1", "LibA-220210-293T_ACE2-2"],
    },
    "condition_2": {
        "name": 220302,
        "selections": ["LibA-220302-293T_ACE2-1", "LibA-220302-293T_ACE2-2"],
    },
    # How effects / differences are averaged; must be "mean" or "median".
    "avg_method": "median",
    # If true, the per-selection effect columns are added to the heatmap tooltips.
    "per_selection_tooltips": True,
    # Keyword arguments for `polyclonal.plot.lineplot_and_heatmap` (defaults for
    # missing entries are filled in before the call).
    "plot_kwargs": {
        "addtl_slider_stats": {
            "times_seen": 3,
            "difference_std": 2,
            "fraction_pairs_w_mutation": 1,
            "best_effect": -2,
            "220210 effect": None,
            "220302 effect": None,
        },
        "addtl_slider_stats_hide_not_filter": [
            "best_effect",
            "220210 effect",
            "220302 effect",
        ],
        "addtl_slider_stats_as_max": ["difference_std"],
        "heatmap_max_at_least": 1,
        "heatmap_min_at_least": -1,
        "init_floor_at_zero": False,
        "init_site_statistic": "mean_abs",
        "site_zoom_bar_color_col": "region",
        "slider_binding_range_kwargs": {"times_seen": {"step": 1, "min": 1, "max": 25}},
    },
    # NOTE(review): `title` and `legend` are not read by any code visible in this
    # notebook (the chart is saved to a "_nolegend" file); presumably they are
    # consumed by a later step -- confirm.
    "title": "Differences in mutation effects on 293T entry across days.",
    "legend": "Interactive plot of the difference of effects of mutations between two conditions.\n\nUse the site zoom bar at the top to zoom in on specific sites. The line plot shows a summary\nstatistic indicating the typical effects of mutations at each site. The heat map shows the effects of\nindividual mutations, with parental amino-acid identities indicated by x, dark gray indicating\nmutations that fail a slider filter, and light gray indicating non-measured mutations.\n\nYou can mouse over points to get details about individual measurements, including measurements\n in individual selection experiments.\n\nThe options at the bottom of the plot let you modify the display, such as by selecting how\nmany different variants a mutation must be seen in to be shown (*minimum times_seen*),\nthe fraction of different pairwise comparisons the mutation was measured in, the maximum\nstandard deviation for the difference among pairwise comparison, or its\nbest effect in any condition.\n",
}
# Input site-numbering map, output CSV of differences, and output chart HTML.
site_numbering_map_csv = "data/site_numbering_map.csv"
diffs_csv = "results/func_effect_diffs/220210_vs_220302_comparison_diffs.csv"
chart_html = "results/func_effect_diffs/220210_vs_220302_comparison_diffs_nolegend.html"

Read the input data:

In [4]:
# Site numbering map, with the reference numbering exposed under the name `site`.
site_numbering_map = pd.read_csv(site_numbering_map_csv)
site_numbering_map = site_numbering_map.rename(columns={"reference_site": "site"})
assert site_numbering_map[["site", "sequential_site"]].notnull().all().all()

# Every other column whose name ends in "site" is kept as an additional site column.
addtl_site_cols = [
    col
    for col in site_numbering_map.columns
    if col.endswith("site") and col != "site"
]

# The two conditions must have different names.
condition_1, condition_2 = (
    params[key]["name"] for key in ("condition_1", "condition_2")
)
assert condition_1 != condition_2, f"{condition_1=}, {condition_2=}"

# Each condition needs a non-empty list of unique selections, and the two
# conditions may not share a selection.
condition_1_selections, condition_2_selections = (
    params[key]["selections"] for key in ("condition_1", "condition_2")
)
assert len(set(condition_1_selections)) == len(condition_1_selections)
assert len(set(condition_2_selections)) == len(condition_2_selections)
assert len(condition_1_selections), params["condition_1"]
assert len(condition_2_selections), params["condition_2"]
if set(condition_1_selections) & set(condition_2_selections):
    raise ValueError(
        f"shared selections in {condition_1_selections=} and {condition_2_selections=}"
    )

# Read the per-selection functional effects and stack them into one tidy frame,
# labelling each row with its selection, its condition, and a mutation string
# of the form <wildtype><site><mutant>.
func_effects = pd.concat(
    [
        pd.read_csv(
            f"results/func_effects/by_selection/{selection}_func_effects.csv"
        ).assign(
            selection=selection,
            condition=condition,
            times_seen=lambda x: x["times_seen"].astype("Int64"),
            mutation=lambda x: x["wildtype"] + x["site"].astype(str) + x["mutant"],
        )
        for condition, selections in (
            (condition_1, condition_1_selections),
            (condition_2, condition_2_selections),
        )
        for selection in selections
    ],
    ignore_index=True,
)

Correlations among all selections¶

Compute the correlations in the mutation effects across all selections:

In [5]:
# The correlations are computed at several `times_seen` thresholds. The initial
# threshold comes from the slider settings; a missing key or an explicit `None`
# falls back to 3.
try:
    init_times_seen = params["plot_kwargs"]["addtl_slider_stats"]["times_seen"]
except KeyError:
    init_times_seen = None
if init_times_seen is None:
    print("No times seen in params, using a value of 3")
    init_times_seen = 3

# Thresholds to analyze. De-duplicate while keeping order: if `init_times_seen`
# is 1 the raw list would be [1, 1, 2], which would concatenate the same rows
# twice under the same `min_times_seen` group.
min_times_seen_values = list(
    dict.fromkeys([1, init_times_seen, 2 * init_times_seen])
)

# do analysis for each "times_seen"
func_effects_for_corr = pd.concat(
    [
        func_effects.query("times_seen >= @t", engine="python").assign(min_times_seen=t)
        for t in min_times_seen_values
    ]
)

# Pairwise correlations between selections, computed separately per threshold.
corrs = (
    dms_variants.utils.tidy_to_corr(
        df=func_effects_for_corr,
        sample_col="selection",
        label_col="mutation",
        value_col="functional_effect",
        group_cols=["min_times_seen"],
    )
    .assign(
        r2=lambda x: x["correlation"] ** 2,
        # turn the numeric threshold into a readable facet label
        min_times_seen=lambda x: "min times seen " + x["min_times_seen"].astype(str),
    )
    .rename(columns={"correlation": "r"})
)

# Heatmap of r2 for each pair of selections, one facet per threshold.
corr_chart = (
    alt.Chart(corrs)
    .encode(
        alt.X("selection_1", title=None),
        alt.Y("selection_2", title=None),
        column=alt.Column("min_times_seen", title=None),
        color=alt.Color("r2", scale=alt.Scale(zero=True)),
        tooltip=[
            alt.Tooltip(c, format=".3g") if c in {"r2", "r"} else c
            for c in ["selection_1", "selection_2", "r2", "r"]
        ],
    )
    .mark_rect(stroke="black")
    .properties(
        width=alt.Step(15),
        height=alt.Step(15),
        title="Per-selection correlation in mutation functional effects",
    )
    .configure_axis(labelLimit=500)
)

display(corr_chart)

print(
    f"\nSelections for {condition_1}: {condition_1_selections}\n"
    f"Selections for {condition_2}: {condition_2_selections}\n"
)
Selections for 220210: ['LibA-220210-293T_ACE2-1', 'LibA-220210-293T_ACE2-2']
Selections for 220302: ['LibA-220302-293T_ACE2-1', 'LibA-220302-293T_ACE2-2']

Average functional effects for each condition¶

Average the functional effects for each condition using the specified averaging method, then print the correlation between these average functional effects at several times seen:

In [6]:
avg_method = params["avg_method"]
assert avg_method in {"mean", "median"}, avg_method

# Average the functional effect of each mutation over the selections of a
# condition. `times_seen` becomes the mean over those selections (sum divided
# by the number of selections with the mutation) and is set to NA for
# wildtype-to-wildtype entries.
avg_func_effects = (
    func_effects.groupby(
        ["condition", "site", "wildtype", "mutant", "mutation"], as_index=False
    )
    .aggregate(
        effect=pd.NamedAgg("functional_effect", avg_method),
        times_seen=pd.NamedAgg("times_seen", "sum"),
        n_selections=pd.NamedAgg("site", "count"),
    )
    .assign(
        times_seen=lambda x: (x["times_seen"] / x["n_selections"]).where(
            x["mutant"] != x["wildtype"],
            pd.NA,
        )
    )
)

# Thresholds at which to report the correlation. De-duplicate while keeping
# order: if `init_times_seen` is 1 the raw list would be [1, 1, 2], which would
# concatenate the same rows twice under the same `min_times_seen` group.
avg_min_times_seen_values = list(
    dict.fromkeys([1, init_times_seen, 2 * init_times_seen])
)

avg_func_effects_for_corr = pd.concat(
    [
        avg_func_effects.query("times_seen >= @t", engine="python").assign(
            min_times_seen=t
        )
        for t in avg_min_times_seen_values
    ]
)
print("Correlation between average functional effects across conditions:")
display(
    dms_variants.utils.tidy_to_corr(
        df=avg_func_effects_for_corr,
        sample_col="condition",
        label_col="mutation",
        value_col="effect",
        group_cols=["min_times_seen"],
    )
    .assign(
        r2=lambda x: x["correlation"] ** 2,
        min_times_seen=lambda x: "min times seen " + x["min_times_seen"].astype(str),
    )
    .rename(columns={"correlation": "r"})
    # drop the self-comparisons, then keep one row per threshold
    .query("condition_1 != condition_2")
    .reset_index(drop=True)
    .groupby("min_times_seen")
    .first()
    .round(3)
)
Correlation between average functional effects across conditions:
condition_1 condition_2 r r2
min_times_seen
min times seen 1 220302 220210 0.867 0.752
min times seen 3 220302 220210 0.879 0.773
min times seen 6 220302 220210 0.848 0.718

Compute pairwise differences¶

Compute pairwise differences in effects for every pairing of a condition 1 selection with a condition 2 selection; each difference is the condition 1 effect minus the condition 2 effect. For each pairing, the times seen is the mean of the times seen in the two selections being compared.

We then average the differences across pairings (using the specified averaging method), and also compute the standard deviation of the differences, the mean times seen, and the fraction of pairings in which the difference could be computed:

In [7]:
# Columns that identify a mutation, and the columns needed from each selection.
mutation_keys = ["wildtype", "site", "mutant"]
per_selection_cols = mutation_keys + ["times_seen", "functional_effect"]

# One frame of differences per (condition 1 selection, condition 2 selection)
# pair. The difference is selection 1 minus selection 2, and the times seen of
# a pair is the mean of the two selections' values.
pairwise_diffs = []
for sel1, sel2 in itertools.product(condition_1_selections, condition_2_selections):
    effects_1 = func_effects.query("selection == @sel1")[per_selection_cols]
    effects_2 = func_effects.query("selection == @sel2")[per_selection_cols]
    paired = effects_1.merge(effects_2, on=mutation_keys, validate="1:1")
    paired = paired.assign(
        times_seen=lambda x: (x["times_seen_x"] + x["times_seen_y"]) / 2,
        difference=lambda x: x["functional_effect_x"] - x["functional_effect_y"],
    )
    pairwise_diffs.append(paired[mutation_keys + ["times_seen", "difference"]])

# Summarize each mutation across the pairs in which it appears.
n_possible_pairs = len(condition_1_selections) * len(condition_2_selections)
diffs = (
    pd.concat(pairwise_diffs, ignore_index=True)
    .groupby(mutation_keys, as_index=False)
    .aggregate(
        difference=pd.NamedAgg("difference", avg_method),
        difference_std=pd.NamedAgg("difference", "std"),
        times_seen=pd.NamedAgg("times_seen", "mean"),
        fraction_pairs_w_mutation=pd.NamedAgg(
            "difference",
            lambda s: len(s) / n_possible_pairs,
        ),
    )
)

# Average effect in each condition as "<condition> effect" columns, plus the
# larger of the two as `best_effect`.
condition_effects = (
    avg_func_effects.pivot_table(
        index=["site", "wildtype", "mutant"],
        values="effect",
        columns="condition",
    )
    .reset_index()
    .assign(best_effect=lambda x: x[[condition_1, condition_2]].max(axis=1))
    .rename(columns={c: f"{c} effect" for c in [condition_1, condition_2]})
)

# One text column per selection of the form "effect (times_seen)"; the
# times-seen suffix is left off for wildtype-to-wildtype entries.
per_selection_effects = (
    func_effects.assign(
        effect_times_seen=lambda x: (
            x["functional_effect"].map(lambda e: f"{e:.2f}")
            + (" (" + x["times_seen"].astype(str) + ")").where(
                x["mutant"] != x["wildtype"],
                "",
            )
        )
    )
    .pivot_table(
        index=["site", "wildtype", "mutant"],
        values="effect_times_seen",
        columns="selection",
        aggfunc=lambda s: ",".join(s),
    )[condition_1_selections + condition_2_selections]
    .reset_index()
)

# Attach both tables to the differences and put rows in a stable order.
diffs = (
    diffs.merge(condition_effects, on=mutation_keys, validate="one_to_one")
    .merge(per_selection_effects, on=mutation_keys, validate="one_to_one")
    .sort_values(["site", "mutant"])
    .reset_index(drop=True)
)

print(f"Writing differences to {diffs_csv}")
diffs.to_csv(diffs_csv, index=False, float_format="%.4g")
Writing differences to results/func_effect_diffs/220210_vs_220302_comparison_diffs.csv

Make a correlation plot¶

Make a correlation plot between the two conditions with informative tooltips and slider bars:

In [8]:
# Highlight the point under the mouse.
mutation_selection = alt.selection_point(
    on="mouseover", fields=["mutation"], empty=False
)

# Data to plot: real mutations only (no wildtype-to-wildtype entries), keyed by
# a single `mutation` label that is placed first among the columns.
corr_diffs = (
    diffs.query("wildtype != mutant")
    .assign(
        mutation=lambda x: x["wildtype"] + x["site"].astype(str) + x["mutant"],
    )
    .drop(columns=["wildtype", "site", "mutant"])
)
corr_diffs = corr_diffs[
    ["mutation"] + [c for c in corr_diffs.columns if c != "mutation"]
]

# Stats whose slider sets a maximum (filter with "<="); all others set a minimum.
max_slider_stats = {"difference_std"}
# Initial slider values configured in the parameters, if any.
configured_slider_inits = params["plot_kwargs"].get("addtl_slider_stats") or {}

sliders = {}
for stat in [
    "times_seen",
    "best_effect",
    "difference_std",
    "fraction_pairs_w_mutation",
    f"{condition_1} effect",
    f"{condition_2} effect",
]:
    stat_is_max = stat in max_slider_stats
    init_value = configured_slider_inits.get(stat)
    if init_value is None:
        # A stat that is absent OR explicitly configured as `None` starts at the
        # non-filtering end of its range. Passing `None` straight through as the
        # slider value would not mean "unfiltered" to the chart's filter.
        init_value = corr_diffs[stat].max() if stat_is_max else corr_diffs[stat].min()
    sliders[stat] = alt.param(
        value=init_value,
        bind=alt.binding_range(
            name=f"maximum {stat}" if stat_is_max else f"minimum {stat}",
            min=corr_diffs[stat].min(),
            max=corr_diffs[stat].max(),
        ),
    )

corr_chart = (
    alt.Chart(corr_diffs)
    .add_params(mutation_selection)
    .encode(
        alt.X(
            f"{condition_1} effect", scale=alt.Scale(nice=False, zero=False, padding=3)
        ),
        alt.Y(
            f"{condition_2} effect", scale=alt.Scale(nice=False, zero=False, padding=3)
        ),
        strokeWidth=alt.condition(mutation_selection, alt.value(2), alt.value(0)),
        tooltip=[
            # `is_float_dtype` also recognizes pandas nullable float columns,
            # which a plain `dtype == float` comparison misses.
            alt.Tooltip(c, format=".3g")
            if pd.api.types.is_float_dtype(corr_diffs[c])
            else c
            for c in corr_diffs.columns
        ],
    )
    .mark_circle(fill="black", fillOpacity=0.30, size=30, stroke="red")
    .properties(width=225, height=225)
    .configure_axis(grid=False)
)

# Attach every slider and filter the points on it.
for stat, slider in sliders.items():
    if stat in max_slider_stats:
        corr_chart = corr_chart.add_params(slider).transform_filter(
            alt.datum[stat] <= slider
        )
    else:
        corr_chart = corr_chart.add_params(slider).transform_filter(
            alt.datum[stat] >= slider
        )

corr_chart
Out[8]:

Make interactive chart¶

Set up keyword arguments to https://jbloomlab.github.io/polyclonal/polyclonal.plot.html#polyclonal.plot.lineplot_and_heatmap if they are not already specified:

In [9]:
# Work on a deep copy: the nested dicts / lists are modified below, and a
# shallow copy would leak those changes back into `params`, which other cells
# read (so re-running cells would see different settings).
plot_kwargs = copy.deepcopy(params["plot_kwargs"])

if "addtl_slider_stats" not in plot_kwargs:
    plot_kwargs["addtl_slider_stats"] = {}

if "times_seen" not in plot_kwargs["addtl_slider_stats"]:
    plot_kwargs["addtl_slider_stats"]["times_seen"] = 3

# `difference_std` is a slider that sets a maximum rather than a minimum.
if "difference_std" not in plot_kwargs["addtl_slider_stats"]:
    plot_kwargs["addtl_slider_stats"]["difference_std"] = diffs["difference_std"].max()
    if "addtl_slider_stats_as_max" not in plot_kwargs:
        plot_kwargs["addtl_slider_stats_as_max"] = ["difference_std"]
    elif "difference_std" not in plot_kwargs["addtl_slider_stats_as_max"]:
        # only append when absent, so the entry is never listed twice
        plot_kwargs["addtl_slider_stats_as_max"].append("difference_std")
elif "addtl_slider_stats_as_max" not in plot_kwargs:
    raise ValueError(
        "You specified `difference_std` in `addtl_slider_stats` but did not add it to "
        "`addtl_slider_stats_as_max`. If you really do not want `difference_std` in "
        "`addtl_slider_stats_as_max`, then specify that list without it."
    )

if "fraction_pairs_w_mutation" not in plot_kwargs["addtl_slider_stats"]:
    plot_kwargs["addtl_slider_stats"]["fraction_pairs_w_mutation"] = 0.5

# If the zoom-bar color column is not already in `diffs`, take it from the site
# numbering map.
# NOTE(review): a column found in neither place is silently left missing here.
if "site_zoom_bar_color_col" in plot_kwargs:
    if plot_kwargs["site_zoom_bar_color_col"] in diffs.columns:
        pass
    elif plot_kwargs["site_zoom_bar_color_col"] in site_numbering_map.columns:
        diffs = diffs.merge(
            site_numbering_map[["site", plot_kwargs["site_zoom_bar_color_col"]]],
            on="site",
            validate="many_to_one",
            how="left",
        )

# Always show the standard deviation and the additional site columns in tooltips.
if "addtl_tooltip_stats" not in plot_kwargs:
    plot_kwargs["addtl_tooltip_stats"] = []
for c in ["difference_std"] + addtl_site_cols:
    if c not in plot_kwargs["addtl_tooltip_stats"]:
        plot_kwargs["addtl_tooltip_stats"].append(c)

# Add the additional site columns to `diffs` (skipped if already present, which
# keeps this cell safe to re-run).
if "sequential_site" not in diffs.columns:
    diffs = diffs.merge(
        site_numbering_map[["site", *addtl_site_cols]],
        on="site",
        validate="many_to_one",
        how="left",
    )
# Show the sequential numbering in tooltips only when it differs from `site`.
if any(diffs["site"] != diffs["sequential_site"]):
    if "sequential_site" not in plot_kwargs["addtl_tooltip_stats"]:
        plot_kwargs["addtl_tooltip_stats"].append("sequential_site")

# Optionally add one tooltip entry per selection.
if params["per_selection_tooltips"]:
    assert set(condition_1_selections + condition_2_selections).issubset(diffs.columns)
    plot_kwargs["addtl_tooltip_stats"] += [
        s
        for s in condition_1_selections + condition_2_selections
        if s not in plot_kwargs["addtl_tooltip_stats"]
    ]

# Default alphabet: the characters present as mutants, in the order returned
# by `biochem_order_aas`.
if "alphabet" not in plot_kwargs:
    plot_kwargs["alphabet"] = [
        a
        for a in polyclonal.alphabets.biochem_order_aas(polyclonal.AAS_WITHSTOP_WITHGAP)
        if a in set(diffs["mutant"])
    ]

# Default site order: all sites of the numbering map, ordered by sequential site.
if "sites" not in plot_kwargs:
    plot_kwargs["sites"] = site_numbering_map.sort_values("sequential_site")[
        "site"
    ].tolist()

Now make the interactive heatmap:

In [10]:
# The plot is given a constant `_dummy` column as its category column, so that
# name must not already be in use.
assert "_dummy" not in diffs.columns
heatmap_df = diffs.assign(_dummy="dummy")

chart = polyclonal.plot.lineplot_and_heatmap(
    data_df=heatmap_df,
    stat_col="difference",
    category_col="_dummy",
    **plot_kwargs,
)

# Show the chart in the notebook, then write it to the HTML output file.
display(chart)

print(f"\nSaving to {chart_html}")
chart.save(chart_html)
Saving to results/func_effect_diffs/220210_vs_220302_comparison_diffs_nolegend.html
In [ ]: